/*
 *  WekaLearning.java
 *
 *  Yaoyong Li 22/03/2007
 *
 *  $Id: WekaLearning.java, v 1.0 2007-03-22 12:58:16 +0000 yaoyong $
 */
package gate.learning.learners.weka;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import gate.learning.ConstantParameters;
import gate.learning.LabelsOfFV;
import gate.learning.LabelsOfFeatureVectorDoc;
import gate.learning.NLPFeaturesList;
import gate.learning.SparseFeatureVector;
import gate.learning.learners.MultiClassLearning;
import gate.util.BomStrippingInputStreamReader;

/**
 * The interface between the Weka learner and the data defined in the ML API,
 * which converts the data into the format a Weka learner can use.
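 * <p>
 * A minimal usage sketch for training and application (the file names and the
 * document and label counts below are hypothetical):
 * <pre>
 * WekaLearning wekaLearning = new WekaLearning();
 * wekaLearning.readSparseFVsFromFile(new File("fvs.dat"), numDocs, true,
 *     numLabels, true);
 * WekaLearner learner = WekaLearning.obtainWekaLearner("KNN", null);
 * wekaLearning.train(learner, new File("wekaModel.save"));
 * // Application mode: read the data without labels, then apply the model.
 * wekaLearning.readSparseFVsFromFile(new File("fvs.dat"), numDocs, false,
 *     numLabels, true);
 * wekaLearning.apply(learner, new File("wekaModel.save"), false);
 * </pre>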
 */
public class WekaLearning {

    /** The data in the Weka object for training or application. */
    public Instances instancesData;
    /** The labels of the instances in every document. */
    public LabelsOfFeatureVectorDoc[] labelsFVDoc = null;
    /** For using the feature vector data. */
    public final static short SPARSEFVDATA = 2;
    /** For using the NLP feature data. */
    public final static short NLPFEATUREFVDATA = 1;

    /** Learn a model and save it into the model file. */
    public void train(WekaLearner wekaCl, File modelFile) {
        // Training.
        wekaCl.training(instancesData);
        // Write the learner object into the model file by class serialisation.
        try {
            if (modelFile.exists()) {
                deleteRecursively(modelFile);
            }
            FileOutputStream modelOutFile = new FileOutputStream(modelFile);
            ObjectOutputStream modelOutputObjectFile = new ObjectOutputStream(
                    modelOutFile);
            try {
                modelOutputObjectFile.writeObject(wekaCl);
                modelOutputObjectFile.flush();
            } finally {
                // Close the stream even if writing fails.
                modelOutputObjectFile.close();
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Recursively delete a file or directory (think "rm -rf file").
     */
    private void deleteRecursively(File file) throws IOException {
        if (!file.exists()) {
            return;
        }
        if (file.isDirectory()) {
            for (File f : file.listFiles()) {
                deleteRecursively(f);
            }
        }
        if (!file.delete()) {
            throw new IOException("Couldn't delete file " + file);
        }
    }

    /** Read the model from the file and apply it to the data. */
    public void apply(WekaLearner wekaCl, File modelFile,
            boolean distributionOutput) {
        // Read the learner object from the model file by class serialisation.
        try {
            FileInputStream modelInFile = new FileInputStream(modelFile);
            ObjectInputStream modelInputObjectFile = new ObjectInputStream(
                    modelInFile);
            try {
                wekaCl = (WekaLearner) modelInputObjectFile.readObject();
            } finally {
                modelInputObjectFile.close();
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return;
        } catch (IOException e) {
            e.printStackTrace();
            return;
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            return;
        }
        // Apply the loaded model to the data; skipped if the model could not
        // be read.
        wekaCl.applying(instancesData, labelsFVDoc, distributionOutput);
    }

    /**
     * Read the sparse feature vector data from the data file and convert it
     * into Weka's instance format.
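     * <p>
     * Each feature vector becomes a {@link SparseInstance} built from parallel
     * value/index arrays; a minimal sketch of the conversion (the values and
     * indexes below are hypothetical):
     * <pre>
     * double[] values = new double[] {1.0, 0.5};
     * int[] indexes = new int[] {3, 17}; // attribute indexes of the non-zero values
     * SparseInstance inst = new SparseInstance(1.0, values, indexes, numFeats + 1);
     * inst.setDataset(instancesData);
     * inst.setClassValue("-1"); // the null class
     * </pre>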
     */
    public void readSparseFVsFromFile(File dataFile, int numDocs,
            boolean trainingMode, int numLabels, boolean surroundMode) {
        int numFeats = 0;
        int numClasses = 0;
        labelsFVDoc = new LabelsOfFeatureVectorDoc[numDocs];
        // Read the sparse FVs by using the method in MultiClassLearning class
        MultiClassLearning multiClassL = new MultiClassLearning();
        boolean isUsingDataFile = false;
        File tempFVDataFile = null;
        multiClassL.getDataFromFile(numDocs, dataFile, isUsingDataFile, tempFVDataFile);
        // Create the attributes.
        numFeats = multiClassL.dataFVinDoc.getTotalNumFeatures();
        FastVector attributes = new FastVector(numFeats + 1);
        for (int i = 0; i < numFeats; ++i) {
            attributes.addElement(new Attribute(Integer.toString(i + 1)));
        }
        // Add class attribute.
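        // In surround mode each label is represented by two classes (one for
        // the start token and one for the end token of an entity), hence the
        // 2 * numLabels below.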
        if (surroundMode) {
            numClasses = 2 * numLabels + 1; // count the null too, as value -1.
        } else {
            numClasses = numLabels + 1;
        }
        FastVector classValues = new FastVector(numClasses);
        classValues.addElement("-1"); // The first class for null class
        for (int i = 1; i < numClasses; ++i) {
            classValues.addElement(new Integer(i).toString());
        }
        attributes.addElement(new Attribute("Class", classValues));
        // Create the dataset with the capacity of all FVs (the actual number
        // of FVs may be larger than pre-specified because of possible
        // multi-labels) and set the index of the class attribute.
        instancesData = new Instances("SparseFVsData", attributes,
                multiClassL.dataFVinDoc.getNumTraining());
        instancesData.setClassIndex(instancesData.numAttributes() - 1);
        // Copy the data into the instances.
        for (int iDoc = 0; iDoc < multiClassL.dataFVinDoc.getNumTrainingDocs(); ++iDoc) {
            SparseFeatureVector[] fvs = multiClassL.dataFVinDoc.trainingFVinDoc[iDoc].getFvs();
            labelsFVDoc[iDoc] = new LabelsOfFeatureVectorDoc();
            labelsFVDoc[iDoc].multiLabels = multiClassL.dataFVinDoc.labelsFVDoc[iDoc].multiLabels;
            for (int i = 0; i < fvs.length; ++i) {
                double[] values = new double[fvs[i].getLen()];
                int[] indexes = new int[fvs[i].getLen()];
                for (int j = 0; j < fvs[i].getLen(); ++j) {
                    values[j] = fvs[i].nodes[j].value;
                    indexes[j] = fvs[i].nodes[j].index;
                }
                // The dataset has numFeats feature attributes plus the class
                // attribute, so use that as the maximal number of values
                // instead of a hard-coded constant.
                SparseInstance inst = new SparseInstance(1.0, values, indexes,
                        numFeats + 1);
                inst.setDataset(instancesData);
                if (trainingMode && labelsFVDoc[iDoc].multiLabels[i].num > 0) {
                    for (int j1 = 0; j1 < labelsFVDoc[iDoc].multiLabels[i].num; ++j1) {
                        // Add one copy of the instance per label (labels > 0).
                        inst.setClassValue(labelsFVDoc[iDoc].multiLabels[i].labels[j1]);
                        instancesData.add(inst);
                    }
                } else {
                    inst.setClassValue("-1"); // Set the label as -1 for the null class.
                    instancesData.add(inst);
                }
            }
        }
    }

    /**
     * Read the NLP feature data from the data file and convert it into Weka's
     * instance format.
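     * <p>
     * Each NLP feature is represented as a nominal attribute; a minimal sketch
     * of how such an attribute is created (the feature name and its values
     * below are hypothetical):
     * <pre>
     * FastVector featValues = new FastVector(2);
     * featValues.addElement("NN");
     * featValues.addElement("VB");
     * attributes.addElement(new Attribute("Pos(0)", featValues));
     * </pre>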
     */
    public void readNLPFeaturesFromFile(File dataFile, int numDocs,
            NLPFeaturesList nlpFeatList, boolean trainingMode, int numLabels,
            boolean surroundMode) {
        labelsFVDoc = new LabelsOfFeatureVectorDoc[numDocs];
        try {
            BufferedReader inData = new BomStrippingInputStreamReader(
                    new FileInputStream(dataFile), "UTF-8");
            // Get the number of attributes in the data
            String[] items = inData.readLine().split(ConstantParameters.ITEMSEPARATOR);
            HashMap metaFeats = new HashMap();
            int numFeats = 0;
            // Create an attribute for each meta feature
            HashMap entityToPosition = new HashMap();
            String entityTerm = "";
            int numEntity = 0;
            // Do not include the class attribute.
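            // Each meta-feature name is assumed to carry a position suffix in
            // parentheses, e.g. "Form(-1)" (a hypothetical example); the code
            // below separates the feature name from that suffix.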
            for (int i = 1; i < items.length; ++i) {
                // Assume that the names of n-gram features end with "gram".
                if (!items[i].endsWith("gram")) {
                    if (!metaFeats.containsKey(items[i])) {
                        metaFeats.put(items[i], new HashSet());
                        ++numFeats; // counted as a new attribute
                    }
                    String feat = items[i].substring(0, items[i].lastIndexOf("("));
                    String featNum = items[i].substring(items[i].lastIndexOf("("));
                    if (!feat.equals(entityTerm)) {
                        numEntity = 0;
                        entityTerm = feat;
                    } else {
                        ++numEntity;
                    }
                    entityToPosition.put(feat + "_" + numEntity, featNum);
                    if (!metaFeats.containsKey(feat)) {
                        metaFeats.put(feat, new HashSet());
                        // Just to collect the terms.
                    }
                }
            }
            List allTerms = new ArrayList(nlpFeatList.featuresList.keySet());
            Collections.sort(allTerms);
            for (int i = 0; i < allTerms.size(); ++i) {
                String feat = allTerms.get(i).toString();
                if (isNgramFeat(feat)) {
                    ++numFeats;
                } else {
                    feat = feat.substring(feat.indexOf("_") + 1);
                    // Name of the entity
                    String feat1 = feat.substring(0, feat.indexOf("_"));
                    // Term itself
                    String feat2 = feat.substring(feat.indexOf("_") + 1);
                    ((HashSet) metaFeats.get(feat1)).add(feat2);
                }
            }
            numFeats += 1; // Include the class attribute.
            // Create the attributes.
            HashMap featToAttr = new HashMap(); // feat to attribute index
            FastVector attributes = new FastVector(numFeats);
            // First, the attributes for the meta features.
            List metaFeatTerms = new ArrayList(metaFeats.keySet());
            int numMetaFeats = 0;
            for (int i = 0; i < metaFeatTerms.size(); ++i) {
                String featName = metaFeatTerms.get(i).toString();
                if (featName.endsWith(")")) {
                    String featName0 = featName.substring(0, featName.lastIndexOf("("));
                    HashSet metaF = (HashSet) metaFeats.get(featName0);
                    FastVector featFV = new FastVector(metaF.size());
                    for (Object obj : metaF) {
                        featFV.addElement(obj.toString());
                    }
                    attributes.addElement(new Attribute(featName, featFV));
                    featToAttr.put(featName, Integer.valueOf(numMetaFeats));
                    ++numMetaFeats;
                }
            }
            // Then the attributes for the terms from the n-gram features.
            int numNgramFeats = 0;
            for (int i = 0; i < allTerms.size(); ++i) {
                String feat = allTerms.get(i).toString();
                if (isNgramFeat(feat)) {
                    FastVector featFV = new FastVector(1);
                    featFV.addElement(feat);
                    attributes.addElement(new Attribute(feat, featFV)); // Nominal form.
                    // Map the feature to its actual position in the attributes
                    // vector: the n-gram attributes follow the meta-feature
                    // attributes, so count them separately rather than reusing
                    // the index over all terms.
                    featToAttr.put(feat, Integer.valueOf(numMetaFeats + numNgramFeats));
                    ++numNgramFeats;
                }
            }
            // Add class attribute.
            int numClasses;
            if (surroundMode) {
                numClasses = 2 * numLabels + 1; // count the null too, as value -1.
            } else {
                numClasses = numLabels + 1;
            }
            FastVector classValues = new FastVector(numClasses);
            classValues.addElement("-1"); // The first class for null class
            for (int i = 1; i < numClasses; ++i) {
                classValues.addElement(new Integer(i).toString());
            }
            attributes.addElement(new Attribute("Class", classValues));
            // Create the dataset with capacity of all FVs, and set index of class
            instancesData = new Instances("NLPFeatureData", attributes, numDocs * 10);
            // The last attribute is the class attribute.
            instancesData.setClassIndex(attributes.size() - 1);
            // Read the data from the file and copy it into the instances.
            for (int iDoc = 0; iDoc < numDocs; ++iDoc) { // For each document
                items = inData.readLine().split(ConstantParameters.ITEMSEPARATOR);
                // The third item is for number of instances in the doc.
                int num = Integer.parseInt(items[2]);
                labelsFVDoc[iDoc] = new LabelsOfFeatureVectorDoc();
                labelsFVDoc[iDoc].multiLabels = new LabelsOfFV[num];
                for (int i = 0; i < num; ++i) { // For each instance
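                    // Each instance line holds the number of labels, then the
                    // labels themselves, and then the NLP feature terms.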
                    items = inData.readLine().split(ConstantParameters.ITEMSEPARATOR);
                    Instance inst = new Instance(numFeats);
                    inst.setDataset(instancesData);
                    int numLabel = Integer.parseInt(items[0]); // Number of labels of the instance.
                    entityTerm = "";
                    numEntity = 0;
                    // For each NLP feature term
                    for (int j = numLabel + 1; j < items.length; ++j) {
                        // Skip the feature if it is not in the list
                        if (!allTerms.contains(items[j])) {
                            continue;
                        }
                        if (isNgramFeat(items[j])) { // An n-gram feature.
                            items[j] = items[j].substring(0,
                                    items[j].lastIndexOf(NLPFeaturesList.SYMBOLNGARM));
                            inst.setValue(((Integer) featToAttr.get(items[j])).intValue(),
                                    items[j]);
                        } else { // Not an n-gram feature.
                            // For real features, not "_NA"
                            if (!items[j].equals(ConstantParameters.NAMENONFEATURE)) {
                                // Get the feature term
                                items[j] = items[j].substring(items[j].indexOf("_") + 1);
                                // Entity name
                                String feat1 = items[j].substring(0, items[j].indexOf("_"));
                                // Feature name
                                String feat2 = items[j].substring(items[j].indexOf("_") + 1);
                                if (!feat1.equals(entityTerm)) {
                                    numEntity = 0;
                                    entityTerm = feat1;
                                } else {
                                    ++numEntity;
                                }
                                feat1 = feat1
                                        + entityToPosition.get(feat1 + "_" + numEntity).toString();
                                inst.setValue(((Integer) featToAttr.get(feat1)).intValue(), feat2);
                            }
                        }
                    }
                    if (trainingMode && numLabel > 0) {
                        labelsFVDoc[iDoc].multiLabels[i] = new LabelsOfFV(numLabel);
                        for (int j = 1; j <= numLabel; ++j) {
                            inst.setClassValue(items[j]);
                            instancesData.add(inst);
                        }
                    } else {
                        labelsFVDoc[iDoc].multiLabels[i] = new LabelsOfFV(0);
                        inst.setClassValue("-1"); // set as null class
                        instancesData.add(inst);
                    }
                } // End of the loop over the instances.
            }
            inData.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Check whether the item is an n-gram feature. */
    private boolean isNgramFeat(String item) {
        return item.contains(NLPFeaturesList.SYMBOLNGARM);
    }

    /**
     * Determine whether a Weka learner uses the NLP feature data or the
     * feature vector data.
     */
    public static short obtainWekaLeanerDataType(String learnerName) {
        if (learnerName.contains("C4.5") || learnerName.contains("NaiveBayes")) {
            return NLPFEATUREFVDATA;
        } else {
            return SPARSEFVDATA;
        }
    }

    /**
     * Obtain the Weka learner specified by the learner's name, optionally
     * configured from the learning options line.
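     * <p>
     * For example (the content of the options line is learner-specific, and
     * the call below is only a sketch):
     * <pre>
     * WekaLearner learner = WekaLearning.obtainWekaLearner("NaiveBayes", null);
     * </pre>
     */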
    public static WekaLearner obtainWekaLearner(String learnerName,
            String learningOpts) {
        WekaLearner wekaL = null;
        if (learnerName.contains("KNN")) {
            if (learningOpts != null) {
                wekaL = new KNNIBK(learningOpts);
            } else {
                wekaL = new KNNIBK();
            }
        } else if (learnerName.contains("NaiveBayes")) {
            wekaL = new NaiveBayesC();
        } else if (learnerName.contains("C4.5")) {
            wekaL = new C45();
        }
        // Guard against an unrecognised learner name before applying the
        // options.
        if (wekaL != null && learningOpts != null) {
            wekaL.getParametersFromOptionsLine(learningOpts);
        }
        return wekaL;
    }

    /** Determine the output type of a Weka learner. */
    public static boolean obtainWekaLearnerOutputType(String learnerName) {
        // All of the currently supported learners (KNN, NaiveBayes and C4.5)
        // have the same output type.
        return true;
    }
}
